{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "545f38ff",
   "metadata": {},
   "source": [
    "# 7-3，FM模型"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f06ffa1c",
   "metadata": {},
   "source": [
    "FM算法全称为因子分解机 (FactorizationMachine)。\n",
    "\n",
    "它是广告和推荐领域非常著名的算法，在线性回归模型上考虑了特征的二阶交互。\n",
    "\n",
    "适合捕捉大规模稀疏特征(类别特征)当中的特征交互。\n",
    "\n",
    "FM及其衍生的一些较有名的算法的简要介绍如下：\n",
    "\n",
    "* FM(FactorizationMachine)：在LR基础上用隐向量点积实现自动化特征二阶交叉，且交互项的计算复杂度是O(n)，效果显著好于LR，速度极快接近LR。\n",
    "\n",
    "* FFM(Field Aware FM): 在FM的基础上考虑对不同的特征域(Field，可以理解成特征的分组)使用不同的隐向量。效果好于FM，但参数量急剧增加，且预测性能急剧下降。\n",
    "\n",
    "* Bilinear-FFM: 双线性FFM。为了减少FFM的参数量，设计共享矩阵来代替针对不同Field的多个隐向量。效果接近FFM，但参数量大大减少，与FM相当。交互后添加LayerNormlization时效果和略好于FFM.\n",
    "\n",
    "* DeepFM: 使用FM模型代替DeepWide中的Wide部分，且FM部分的隐向量与Deep部分的Embedding向量是共享的。FM部分可以捕获二阶显式特征交叉，而Deep部分能够捕获高阶隐式特征组合和交叉。\n",
    "\n",
    "* FiBiNET: 使用SE注意力(Squeeze-and-Excitation)机制来捕获特征重要性，并且使用Bilinear-FFM来捕获二阶特征交互。\n",
    "\n",
    "参考文章：张俊林《FFM及DeepFFM模型在推荐系统的探索》https://zhuanlan.zhihu.com/p/67795161\n",
    "\n",
    "<br>\n",
    "\n",
    "<font color=\"red\">\n",
    " \n",
    "公众号 **算法美食屋** 回复关键词：**pytorch**， 获取本项目源码和所用数据集百度云盘下载链接。\n",
    "    \n",
    "</font> \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f3e9c86",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "print(\"torch.__version__=\"+torch.__version__) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3bef545",
   "metadata": {},
   "source": [
    "```\n",
    "torch.__version__=1.10.0\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78bb862c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "2740610c",
   "metadata": {},
   "source": [
    "##  一，FM原理解析\n",
    "\n",
    "FM模型的表达形式如下：\n",
    "\n",
    "$$y_{FM} = x_0 + \\sum_{i=1}^n \\omega_i x_i + \\sum_{i=1}^{n-1}\\sum_{j=i+1}^{n} <\\vec{v_i},\\vec{v_j}> x_i x_j$$\n",
    "\n",
    "其中 前两项与 线性回归一致。\n",
    "\n",
    "第三项为特征交互项。用隐向量的点积来计算交互项的系数。这样做比直接设定一个$n\\times n$的交互参数矩阵$W$的好处是减少了参数数量，参数数量从 $n^2$减少为 $n\\times k$，其中k为隐向量$v_i$的长度。\n",
    "\n",
    "从数学上，FM算法用一组向量的两两内积代替了交互参数矩阵$W$，等价于将对称矩阵W分解成如下形式$W=V^TV$，这也是为什么FM算法被叫做因子分解机。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c34c64fa",
   "metadata": {},
   "source": [
    "非常有意思的是，交互项的计算复杂度也可以由 $O(n^2)$ 降低为 $O(nk)$，这样FM前向推断的计算复杂度近似为线性复杂度。对于特征数量n非常大而稀疏的模型，计算起来毫无压力。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "273f8473",
   "metadata": {},
   "source": [
    "交互项的简化计算类似于 $ab+ac+bc =\\frac{1}{2} ((a+b+c)^2-(a^2+b^2+c^2))$"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f47f917",
   "metadata": {},
   "source": [
    "$$\\sum_{i=1}^{n-1}\\sum_{j=i+1}^{n} <\\vec{v_i},\\vec{v_j}> x_i x_j\n",
    "= \\frac{1}{2}(\\sum_{i=1}^{n}\\sum_{j=1}^{n} <\\vec{v_i},\\vec{v_j}> x_i x_j - \\sum_{i=1}^{n} <\\vec{v_i},\\vec{v_i}> x_i x_i)$$\n",
    "$$= \\frac{1}{2}(\\sum_{i=1}^{n}\\sum_{j=1}^{n} \\sum_{f=1}^{k} v_{if}v_{jf} x_i x_j - \\sum_{i=1}^{n} \\sum_{f=1}^{k} v_{if}v_{if} x_i x_i)$$\n",
    "\n",
    "$$= \\frac{1}{2}\\sum_{f=1}^{k}(\\sum_{i=1}^{n}\\sum_{j=1}^{n}  v_{if}v_{jf} x_i x_j - \\sum_{i=1}^{n}  v_{if}v_{if} x_i x_i)$$\n",
    "\n",
    "$$= \\frac{1}{2}\\sum_{f=1}^{k}((\\sum_{i=1}^{n}v_{if}x_i)^2  - \\sum_{i=1}^{n}  (v_{if} x_i)^2)$$\n",
    "\n",
    "可以看到交互项的计算复杂度已经变成 $O(nk)$ 了\n",
    "\n",
    "因此 FM的模型形式也可以改写成：\n",
    "\n",
    "$$y_{FM} = x_0 + \\sum_{i=1}^n \\omega_i x_i +\\frac{1}{2}\\sum_{f=1}^{k}((\\sum_{i=1}^{n}v_{if}x_i)^2  - \\sum_{i=1}^{n}  (v_{if} x_i)^2)$$"
   ]
  },
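  {
   "cell_type": "markdown",
   "id": "fm0a0001",
   "metadata": {},
   "source": [
    "The identity derived above can be checked numerically. The snippet below is a minimal sketch, not part of the original implementation: the sizes n and k and the random tensors x and V are made-up values for illustration. It compares the naive pairwise sum with the sum-of-squares form.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "n, k = 6, 4                                # hypothetical sizes: n features, latent dimension k\n",
    "x = torch.randn(n, dtype=torch.float64)    # feature values x_i of one sample\n",
    "V = torch.randn(n, k, dtype=torch.float64) # row i is the latent vector v_i\n",
    "\n",
    "# Naive form: sum over all pairs i < j of <v_i, v_j> * x_i * x_j\n",
    "naive = sum(\n",
    "    torch.dot(V[i], V[j]) * x[i] * x[j]\n",
    "    for i in range(n - 1)\n",
    "    for j in range(i + 1, n)\n",
    ")\n",
    "\n",
    "# Fast form: 0.5 * sum_f ((sum_i v_if x_i)^2 - sum_i (v_if x_i)^2)\n",
    "xv = x[:, None] * V                        # shape (n, k), row i is x_i * v_i\n",
    "fast = 0.5 * (xv.sum(dim=0) ** 2 - (xv ** 2).sum(dim=0)).sum()\n",
    "\n",
    "assert torch.allclose(naive, fast)\n",
    "```\n",
    "\n",
    "The fast form touches each entry of the (n, k) tensor only a constant number of times, which is where the $O(nk)$ cost comes from."
   ]
  },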
  {
   "cell_type": "markdown",
   "id": "4c353cfc",
   "metadata": {},
   "source": [
    "注意到 \n",
    "\n",
    "$$\\frac{\\partial{y_{FM}}}{\\partial{v_{if}}} = (\\sum_{j=1}^{n}v_{jf}x_j) x_i - v_{if}x_i^2$$\n",
    "$$= x_i((\\sum_{j=1}^{n}v_{jf}x_j)  - v_{if}x_i)$$\n",
    "\n",
    "可见，只要训练样本中存在不等于0的 $x_i$ ，就能够给隐向量$\\vec{v_{i}}$贡献梯度，从而学到有效的$\\vec{v_{i}}$表示。\n",
    "\n",
    "同理，只要训练样本中存在不等于0的 $x_j$ ，就能够给隐向量$\\vec{v_{j}}$贡献梯度，从而学到有效的$\\vec{v_{j}}$表示。\n",
    "\n",
    "然后，就可以计算出有意义的交互项的权重$<\\vec{v_{i}},\\vec{v_{j}}>$。\n",
    "\n",
    "这非常重要，这说明非零的交互项权重可以在训练样本中不存在 $x_i$和$x_j$同时不为0的样本的发生。\n",
    "\n",
    "这是FM面对稀疏特征具有很强泛化性的原因。\n",
    "\n",
    "考虑一个典型的给用户推荐商品的推荐场景中，用户所在城市特征和商品类目特征的交互。\n",
    "\n",
    "葫芦岛是一个小城市，渔网是一种小众商品。它们都是稀疏特征，绝大部分样本在这两个onehot位上的取值都是0.\n",
    "\n",
    "稀疏乘以稀疏更加稀疏，所以在训练样本中可能根本不存在葫芦岛城市的用户购买渔网这样的样本。\n",
    "\n",
    "但是只要训练样本中存在着葫芦岛的用户购买其它商品这样的样本，也存在其他城市用户购买渔网这样的样本，FM模型就可以给葫芦岛市的用户购买渔网的可能性作出一个估计，这个值可能不小，最后甚至会给葫芦岛的用户推荐渔网。\n",
    "\n",
    "这就是FM面对稀疏特征具有很强泛化性的一个例子。\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
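  {
   "cell_type": "markdown",
   "id": "fm0a0002",
   "metadata": {},
   "source": [
    "The gradient formula can be compared against autograd. This is an illustrative sketch under assumed toy values (the sizes n and k, the random x and V, and the choice of zeroing feature 2 are all made up here). It also shows that a feature whose value is 0 in a sample receives no gradient from that sample.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "n, k = 5, 3                                # hypothetical sizes\n",
    "x = torch.randn(n, dtype=torch.float64)\n",
    "x[2] = 0.0                                 # feature 2 is absent from this sample\n",
    "V = torch.randn(n, k, dtype=torch.float64, requires_grad=True)\n",
    "\n",
    "# Interaction term in its O(nk) form, differentiated by autograd\n",
    "xv = x[:, None] * V\n",
    "interaction = 0.5 * (xv.sum(dim=0) ** 2 - (xv ** 2).sum(dim=0)).sum()\n",
    "interaction.backward()\n",
    "\n",
    "# Closed form: dy/dv_if = x_i * (sum_j v_jf x_j - v_if x_i)\n",
    "with torch.no_grad():\n",
    "    s = (x[:, None] * V).sum(dim=0)        # shape (k,), s_f = sum_j v_jf x_j\n",
    "    manual = x[:, None] * (s[None, :] - V * x[:, None])\n",
    "\n",
    "assert torch.allclose(V.grad, manual)\n",
    "assert torch.all(V.grad[2] == 0)           # x_2 = 0, so v_2 gets no gradient here\n",
    "```\n",
    "\n",
    "Every other row of the gradient is generally non-zero, so each latent vector is updated by all samples in which its own feature is active, whichever other features it co-occurs with."
   ]
  },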
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a78b3ae6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "8a3c6263",
   "metadata": {},
   "source": [
    "## 二，Pytorch代码实现"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d007db0c",
   "metadata": {},
   "source": [
    "下面是FM模型的一个完整pytorch实现。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db12d6bb",
   "metadata": {},
   "source": [
    "$$\\sum_{i=1}^{n-1}\\sum_{j=i+1}^{n} <\\vec{v_i},\\vec{v_j}> x_i x_j = \\sum_{i=1}^{n-1}\\sum_{j=i+1}^{n} <x_i\\vec{v_i},x_j\\vec{v_j}> $$\n",
    "\n",
    "注意的是，我们代码中的embedding向量或者线性层作用结果实际上是 $x_i\\vec{v_i}$ 的结果。这是许多读者包括我在学习FM时候感到困惑的一个地方。\n",
    "\n",
    "对于 离散特征，onehot编码后其 $x_i $ 总是等于1或者0，$x_i$不为0的那些项才会保留到结果中，此时$x_i$总是等于1，因此$x_i\\vec{v_i}$就等于其embedding向量。对于连续特征，通过一个不带偏置的Linear层作用，获取到的实际上就是 $x_i\\vec{v_i}$，包含了$x_i$因子。\n"
   ]
  },
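  {
   "cell_type": "markdown",
   "id": "fm0a0003",
   "metadata": {},
   "source": [
    "The two cases can be illustrated with plain PyTorch layers. This is a small sketch with made-up sizes and inputs (num_categories, k, idx and x are hypothetical), separate from the classes defined in the next cell.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "num_categories, k = 5, 4          # hypothetical vocabulary size and latent dimension\n",
    "embedding = nn.Embedding(num_categories, k)\n",
    "\n",
    "idx = torch.tensor([3, 0, 3])     # category indices of three samples\n",
    "onehot = F.one_hot(idx, num_classes=num_categories).float()\n",
    "\n",
    "# Discrete feature: one-hot times the weight matrix selects one row,\n",
    "# i.e. x_i * v_i with x_i = 1 is just the embedding vector\n",
    "assert torch.allclose(embedding(idx), onehot @ embedding.weight)\n",
    "\n",
    "# Continuous feature: a bias-free Linear(1, k) maps the scalar x_i to x_i * v_i\n",
    "linear = nn.Linear(1, k, bias=False)\n",
    "x = torch.tensor([[2.5]])\n",
    "v = linear.weight[:, 0]           # the latent vector v_i, shape (k,)\n",
    "assert torch.allclose(linear(x), x * v)\n",
    "```\n",
    "\n",
    "In both cases the tensor handed to the interaction layer already includes the factor $x_i$, so the layer itself only needs sums and squares."
   ]
  },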
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a530e37",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "from torch import nn\n",
    "from torch import nn,Tensor \n",
    "import torch.nn.functional as F \n",
    "\n",
    "class NumEmbedding(nn.Module):\n",
    "    \"\"\"\n",
    "    连续特征用linear层编码\n",
    "    输入shape: [batch_size,features_num(n), d_in], # d_in 通常是1\n",
    "    输出shape: [batch_size,features_num(n), d_out]\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, n: int, d_in: int, d_out: int, bias: bool = False) -> None:\n",
    "        super().__init__()\n",
    "        self.weight = nn.Parameter(Tensor(n, d_in, d_out))\n",
    "        self.bias = nn.Parameter(Tensor(n, d_out)) if bias else None\n",
    "        with torch.no_grad():\n",
    "            for i in range(n):\n",
    "                layer = nn.Linear(d_in, d_out)\n",
    "                self.weight[i] = layer.weight.T\n",
    "                if self.bias is not None:\n",
    "                    self.bias[i] = layer.bias\n",
    "\n",
    "    def forward(self, x_num):\n",
    "        # x_num: batch_size, features_num, d_in\n",
    "        assert x_num.ndim == 3\n",
    "        #x = x_num[..., None] * self.weight[None]\n",
    "        #x = x.sum(-2)\n",
    "        x = torch.einsum(\"bfi,fij->bfj\",x_num,self.weight)\n",
    "        if self.bias is not None:\n",
    "            x = x + self.bias[None]\n",
    "        return x\n",
    "    \n",
    "class CatEmbedding(nn.Module):\n",
    "    \"\"\"\n",
    "    离散特征用Embedding层编码\n",
    "    输入shape: [batch_size,features_num], \n",
    "    输出shape: [batch_size,features_num, d_embed]\n",
    "    \"\"\"\n",
    "    def __init__(self, categories, d_embed):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(sum(categories), d_embed)\n",
    "        self.offsets = nn.Parameter(\n",
    "                torch.tensor([0] + categories[:-1]).cumsum(0),requires_grad=False)\n",
    "        \n",
    "        torch.nn.init.xavier_uniform_(self.embedding.weight.data)\n",
    "\n",
    "    def forward(self, x_cat):\n",
    "        \"\"\"\n",
    "        :param x_cat: Long tensor of size ``(batch_size, features_num)``\n",
    "        \"\"\"\n",
    "        x = x_cat + self.offsets[None]\n",
    "        return self.embedding(x) \n",
    "    \n",
    "class CatLinear(nn.Module):\n",
    "    \"\"\"\n",
    "    离散特征用Embedding实现线性层（等价于先F.onehot再nn.Linear()）\n",
    "    输入shape: [batch_size,features_num], \n",
    "    输出shape: [batch_size,d_out]\n",
    "    \"\"\"\n",
    "    def __init__(self, categories, d_out=1):\n",
    "        super().__init__()\n",
    "        self.fc = nn.Embedding(sum(categories), d_out)\n",
    "        self.bias = nn.Parameter(torch.zeros((d_out,)))\n",
    "        self.offsets = nn.Parameter(\n",
    "                torch.tensor([0] + categories[:-1]).cumsum(0),requires_grad=False)\n",
    "\n",
    "    def forward(self, x_cat):\n",
    "        \"\"\"\n",
    "        :param x: Long tensor of size ``(batch_size, features_num)``\n",
    "        \"\"\"\n",
    "        x = x_cat + self.offsets[None]\n",
    "        return torch.sum(self.fc(x), dim=1) + self.bias \n",
    "    \n",
    "    \n",
    "class FMLayer(nn.Module):\n",
    "    \"\"\"\n",
    "    FM交互项\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, reduce_sum=True):\n",
    "        super().__init__()\n",
    "        self.reduce_sum = reduce_sum\n",
    "\n",
    "    def forward(self, x): #注意：这里的x是公式中的 <v_i> * xi\n",
    "        \"\"\"\n",
    "        :param x: Float tensor of size ``(batch_size, num_features, k)``\n",
    "        \"\"\"\n",
    "        square_of_sum = torch.sum(x, dim=1) ** 2\n",
    "        sum_of_square = torch.sum(x ** 2, dim=1)\n",
    "        ix = square_of_sum - sum_of_square\n",
    "        if self.reduce_sum:\n",
    "            ix = torch.sum(ix, dim=1, keepdim=True)\n",
    "        return 0.5 * ix\n",
    "    \n",
    "class FM(nn.Module):\n",
    "    \"\"\"\n",
    "    完整FM模型。\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d_numerical, categories=None, d_embed=4,\n",
    "                 n_classes = 1):\n",
    "        super().__init__()\n",
    "        if d_numerical is None:\n",
    "            d_numerical = 0\n",
    "        if categories is None:\n",
    "            categories = []\n",
    "        self.categories = categories\n",
    "        self.n_classes = n_classes\n",
    "        \n",
    "        self.num_linear = nn.Linear(d_numerical,n_classes) if d_numerical else None\n",
    "        self.cat_linear = CatLinear(categories,n_classes) if categories else None\n",
    "        \n",
    "        self.num_embedding = NumEmbedding(d_numerical,1,d_embed) if d_numerical else None\n",
    "        self.cat_embedding = CatEmbedding(categories, d_embed) if categories else None\n",
    "        \n",
    "        if n_classes==1:\n",
    "            self.fm = FMLayer(reduce_sum=True)\n",
    "            self.fm_linear = None\n",
    "        else:\n",
    "            assert n_classes>=2\n",
    "            self.fm = FMLayer(reduce_sum=False)\n",
    "            self.fm_linear = nn.Linear(d_embed,n_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \n",
    "        \"\"\"\n",
    "        x_num: numerical features\n",
    "        x_cat: category features\n",
    "        \"\"\"\n",
    "        x_num,x_cat = x\n",
    "        \n",
    "        #linear部分\n",
    "        x = 0.0\n",
    "        if self.num_linear:\n",
    "            x = x + self.num_linear(x_num) \n",
    "        if self.cat_linear:\n",
    "            x = x + self.cat_linear(x_cat)\n",
    "        \n",
    "        #交叉项部分\n",
    "        x_embedding = []\n",
    "        if self.num_embedding:\n",
    "            x_embedding.append(self.num_embedding(x_num[...,None]))\n",
    "        if self.cat_embedding:\n",
    "            x_embedding.append(self.cat_embedding(x_cat))\n",
    "        x_embedding = torch.cat(x_embedding,dim=1)\n",
    "        \n",
    "        if self.n_classes==1:\n",
    "            x = x + self.fm(x_embedding)\n",
    "            x = x.squeeze(-1)\n",
    "        else: \n",
    "            x = x + self.fm_linear(self.fm(x_embedding)) \n",
    "        return x\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ff538b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "##测试 NumEmbedding\n",
    "\n",
    "num_embedding = NumEmbedding(2,1,4)\n",
    "x_num = torch.randn(2,2)\n",
    "x_out = (num_embedding(x_num.unsqueeze(-1)))\n",
    "print(x_out.shape)        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca9ef57f",
   "metadata": {},
   "outputs": [],
   "source": [
    "##测试 CatEmbedding\n",
    "\n",
    "cat_embedding = CatEmbedding(categories = [3,2,2],d_embed=4) \n",
    "x_cat = torch.randint(0,2,(2,3))\n",
    "x_out = cat_embedding(x_cat)\n",
    "print(x_cat.shape)\n",
    "print(x_out.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb503f2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "##测试 CatLinear\n",
    "\n",
    "cat_linear = CatLinear(categories = [3,2,2],d_out=1) \n",
    "x_cat = torch.randint(0,2,(2,3))\n",
    "x_out = cat_linear(x_cat)\n",
    "print(x_cat.shape)\n",
    "print(x_out.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b09022e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "##测试 FMLayer\n",
    "\n",
    "fm_layer = FMLayer(reduce_sum=False)\n",
    "\n",
    "x = torch.randn(2,3,4)\n",
    "x_out = fm_layer(x)\n",
    "print(x_out.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "050b8e23",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "##测试 FM\n",
    "\n",
    "fm = FM(d_numerical = 3, categories = [4,3,2],\n",
    "        d_embed = 4,n_classes = 2)\n",
    "self = fm \n",
    "x_num = torch.randn(2,3)\n",
    "x_cat = torch.randint(0,2,(2,3))\n",
    "fm((x_num,x_cat))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29e2b226",
   "metadata": {},
   "source": [
    "## 三，Cretio数据集完整范例"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b6e89a2",
   "metadata": {},
   "source": [
    "Cretio数据集是一个经典的广告点击率CTR预测数据集。\n",
    "\n",
    "这个数据集的目标是通过用户特征和广告特征来预测某条广告是否会为用户点击。\n",
    "\n",
    "数据集有13维数值特征(I1至I13)和26维类别特征(C14至C39), 共39维特征, 特征中包含着许多缺失值。\n",
    "\n",
    "训练集4000万个样本，测试集600万个样本。数据集大小超过100G.\n",
    "\n",
    "此处使用的是采样100万个样本后的cretio_small数据集。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2680b15",
   "metadata": {},
   "outputs": [],
   "source": [
    "#!pip install torchkeras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ada2b70",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "import pandas as pd \n",
    "import datetime \n",
    "\n",
    "from sklearn.model_selection import train_test_split \n",
    "\n",
    "import torch \n",
    "from torch import nn \n",
    "from torch.utils.data import Dataset,DataLoader  \n",
    "import torch.nn.functional as F \n",
    "import torchkeras \n",
    "\n",
    "def printlog(info):\n",
    "    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "    print(\"\\n\"+\"==========\"*8 + \"%s\"%nowtime)\n",
    "    print(info+'...\\n\\n')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21f5b451",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "fa3283ff",
   "metadata": {},
   "source": [
    "### 1，准备数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d72dcaeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import LabelEncoder,QuantileTransformer\n",
    "from sklearn.pipeline import Pipeline \n",
    "from sklearn.impute import SimpleImputer \n",
    "\n",
    "dfdata = pd.read_csv(\"./eat_pytorch_datasets/criteo_small.zip\",sep=\"\\t\",header=None)\n",
    "dfdata.columns = [\"label\"] + [\"I\"+str(x) for x in range(1,14)] + [\n",
    "    \"C\"+str(x) for x in range(14,40)]\n",
    "\n",
    "cat_cols = [x for x in dfdata.columns if x.startswith('C')]\n",
    "num_cols = [x for x in dfdata.columns if x.startswith('I')]\n",
    "num_pipe = Pipeline(steps = [('impute',SimpleImputer()),('quantile',QuantileTransformer())])\n",
    "\n",
    "for col in cat_cols:\n",
    "    dfdata[col]  = LabelEncoder().fit_transform(dfdata[col])\n",
    "\n",
    "dfdata[num_cols] = num_pipe.fit_transform(dfdata[num_cols])\n",
    "\n",
    "categories = [dfdata[col].max()+1 for col in cat_cols]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec01fe7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "from torch.utils.data import Dataset,DataLoader \n",
    "\n",
    "#DataFrame转换成torch数据集Dataset, 特征分割成X_num,X_cat方式\n",
    "class DfDataset(Dataset):\n",
    "    def __init__(self,df,\n",
    "                 label_col,\n",
    "                 num_features,\n",
    "                 cat_features,\n",
    "                 categories,\n",
    "                 is_training=True):\n",
    "        \n",
    "        self.X_num = torch.tensor(df[num_features].values).float() if num_features else None\n",
    "        self.X_cat = torch.tensor(df[cat_features].values).long() if cat_features else None\n",
    "        self.Y = torch.tensor(df[label_col].values).float() \n",
    "        self.categories = categories\n",
    "        self.is_training = is_training\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.Y)\n",
    "    \n",
    "    def __getitem__(self,index):\n",
    "        if self.is_training:\n",
    "            return ((self.X_num[index],self.X_cat[index]),self.Y[index])\n",
    "        else:\n",
    "            return (self.X_num[index],self.X_cat[index])\n",
    "    \n",
    "    def get_categories(self):\n",
    "        return self.categories\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb9b516d",
   "metadata": {},
   "outputs": [],
   "source": [
    "dftrain_val,dftest = train_test_split(dfdata,test_size=0.2)\n",
    "dftrain,dfval = train_test_split(dftrain_val,test_size=0.2)\n",
    "\n",
    "ds_train = DfDataset(dftrain,label_col = \"label\",num_features = num_cols,cat_features = cat_cols,\n",
    "                    categories = categories, is_training=True)\n",
    "\n",
    "ds_val = DfDataset(dfval,label_col = \"label\",num_features = num_cols,cat_features = cat_cols,\n",
    "                    categories = categories, is_training=True)\n",
    "\n",
    "ds_test = DfDataset(dftest,label_col = \"label\",num_features = num_cols,cat_features = cat_cols,\n",
    "                    categories = categories, is_training=True)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb40472b",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "dl_train = DataLoader(ds_train,batch_size = 2048,shuffle=True)\n",
    "dl_val = DataLoader(ds_val,batch_size = 2048,shuffle=False)\n",
    "dl_test = DataLoader(ds_test,batch_size = 2048,shuffle=False)\n",
    "\n",
    "for features,labels in dl_train:\n",
    "    break \n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3fffe25",
   "metadata": {},
   "source": [
    "### 2，定义模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6b54af6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_net():\n",
    "    net = FM(\n",
    "        d_numerical= ds_train.X_num.shape[1],\n",
    "        categories= ds_train.get_categories(),\n",
    "        d_embed = 8, \n",
    "        n_classes = 1\n",
    "    )\n",
    "    return net \n",
    "\n",
    "from torchkeras import summary\n",
    "\n",
    "net = create_net()\n",
    "summary(net,input_data=features);\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9446731b",
   "metadata": {},
   "source": [
    "```\n",
    "--------------------------------------------------------------------------\n",
    "Layer (type)                            Output Shape              Param #\n",
    "==========================================================================\n",
    "Linear-1                                     [-1, 1]                   14\n",
    "Embedding-2                              [-1, 26, 1]            1,296,709\n",
    "NumEmbedding-3                           [-1, 13, 8]                  104\n",
    "Embedding-4                              [-1, 26, 8]           10,373,672\n",
    "FMLayer-5                                    [-1, 1]                    0\n",
    "==========================================================================\n",
    "Total params: 11,670,499\n",
    "Trainable params: 11,670,499\n",
    "Non-trainable params: 0\n",
    "--------------------------------------------------------------------------\n",
    "Input size (MB): 0.000084\n",
    "Forward/backward pass size (MB): 0.002594\n",
    "Params size (MB): 44.519421\n",
    "Estimated Total Size (MB): 44.522099\n",
    "--------------------------------------------------------------------------\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "969c88de",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "9547e067",
   "metadata": {},
   "source": [
    "### 3，训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cd40c3a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8cdc3b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,sys,time\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import datetime \n",
    "from tqdm import tqdm \n",
    "\n",
    "import torch\n",
    "from torch import nn \n",
    "from accelerate import Accelerator\n",
    "from copy import deepcopy\n",
    "\n",
    "\n",
    "def printlog(info):\n",
    "    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "    print(\"\\n\"+\"==========\"*8 + \"%s\"%nowtime)\n",
    "    print(str(info)+\"\\n\")\n",
    "    \n",
    "class StepRunner:\n",
    "    def __init__(self, net, loss_fn,stage = \"train\", metrics_dict = None, \n",
    "                 optimizer = None, lr_scheduler = None,\n",
    "                 accelerator = None\n",
    "                 ):\n",
    "        self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage\n",
    "        self.optimizer,self.lr_scheduler = optimizer,lr_scheduler\n",
    "        self.accelerator = accelerator\n",
    "    \n",
    "    def __call__(self, features, labels):\n",
    "        #loss\n",
    "        preds = self.net(features)\n",
    "        loss = self.loss_fn(preds,labels)\n",
    "\n",
    "        #backward()\n",
    "        if self.optimizer is not None and self.stage==\"train\":\n",
    "            if self.accelerator is  None:\n",
    "                loss.backward()\n",
    "            else:\n",
    "                self.accelerator.backward(loss)\n",
    "            self.optimizer.step()\n",
    "            if self.lr_scheduler is not None:\n",
    "                self.lr_scheduler.step()\n",
    "            self.optimizer.zero_grad()\n",
    "            \n",
    "        #metrics\n",
    "        step_metrics = {self.stage+\"_\"+name:metric_fn(preds, labels).item() \n",
    "                        for name,metric_fn in self.metrics_dict.items()}\n",
    "        return loss.item(),step_metrics\n",
    "    \n",
    "    \n",
    "class EpochRunner:\n",
    "    def __init__(self,steprunner):\n",
    "        self.steprunner = steprunner\n",
    "        self.stage = steprunner.stage\n",
    "        self.steprunner.net.train() if self.stage==\"train\" else self.steprunner.net.eval()\n",
    "        \n",
    "    def __call__(self,dataloader):\n",
    "        total_loss,step = 0,0\n",
    "        loop = tqdm(enumerate(dataloader), total =len(dataloader))\n",
    "        for i, batch in loop:\n",
    "            features,labels = batch\n",
    "            if self.stage==\"train\":\n",
    "                loss, step_metrics = self.steprunner(features,labels)\n",
    "            else:\n",
    "                with torch.no_grad():\n",
    "                    loss, step_metrics = self.steprunner(features,labels)\n",
    "                    \n",
    "            step_log = dict({self.stage+\"_loss\":loss},**step_metrics)\n",
    "\n",
    "            total_loss += loss\n",
    "            step+=1\n",
    "            if i!=len(dataloader)-1:\n",
    "                loop.set_postfix(**step_log)\n",
    "            else:\n",
    "                epoch_loss = total_loss/step\n",
    "                epoch_metrics = {self.stage+\"_\"+name:metric_fn.compute().item() \n",
    "                                 for name,metric_fn in self.steprunner.metrics_dict.items()}\n",
    "                epoch_log = dict({self.stage+\"_loss\":epoch_loss},**epoch_metrics)\n",
    "                loop.set_postfix(**epoch_log)\n",
    "\n",
    "                for name,metric_fn in self.steprunner.metrics_dict.items():\n",
    "                    metric_fn.reset()\n",
    "        return epoch_log\n",
    "\n",
    "class KerasModel(torch.nn.Module):\n",
    "    def __init__(self,net,loss_fn,metrics_dict=None,optimizer=None,lr_scheduler = None):\n",
    "        super().__init__()\n",
    "        self.accelerator = Accelerator()\n",
    "        self.history = {}\n",
    "        \n",
    "        self.net = net\n",
    "        self.loss_fn = loss_fn\n",
    "        self.metrics_dict = nn.ModuleDict(metrics_dict) \n",
    "        \n",
    "        self.optimizer = optimizer if optimizer is not None else torch.optim.Adam(\n",
    "            self.parameters(), lr=1e-2)\n",
    "        self.lr_scheduler = lr_scheduler\n",
    "        \n",
    "        self.net,self.loss_fn,self.metrics_dict,self.optimizer = self.accelerator.prepare(\n",
    "            self.net,self.loss_fn,self.metrics_dict,self.optimizer)\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.net:\n",
    "            return self.net.forward(x)\n",
    "        else:\n",
    "            raise NotImplementedError\n",
    "\n",
    "\n",
    "    def fit(self, train_data, val_data=None, epochs=10, ckpt_path='checkpoint.pt', \n",
    "            patience=5, monitor=\"val_loss\", mode=\"min\"):\n",
    "        \n",
    "        train_data = self.accelerator.prepare(train_data)\n",
    "        val_data = self.accelerator.prepare(val_data) if val_data else []\n",
    "\n",
    "        for epoch in range(1, epochs+1):\n",
    "            printlog(\"Epoch {0} / {1}\".format(epoch, epochs))\n",
    "            \n",
    "            # 1，train -------------------------------------------------  \n",
    "            train_step_runner = StepRunner(net = self.net,stage=\"train\",\n",
    "                    loss_fn = self.loss_fn,metrics_dict=deepcopy(self.metrics_dict),\n",
    "                    optimizer = self.optimizer, lr_scheduler = self.lr_scheduler,\n",
    "                    accelerator = self.accelerator)\n",
    "            train_epoch_runner = EpochRunner(train_step_runner)\n",
    "            train_metrics = train_epoch_runner(train_data)\n",
    "            \n",
    "            for name, metric in train_metrics.items():\n",
    "                self.history[name] = self.history.get(name, []) + [metric]\n",
    "\n",
    "            # 2，validate -------------------------------------------------\n",
    "            if val_data:\n",
    "                val_step_runner = StepRunner(net = self.net,stage=\"val\",\n",
    "                    loss_fn = self.loss_fn,metrics_dict=deepcopy(self.metrics_dict),\n",
    "                    accelerator = self.accelerator)\n",
    "                val_epoch_runner = EpochRunner(val_step_runner)\n",
    "                with torch.no_grad():\n",
    "                    val_metrics = val_epoch_runner(val_data)\n",
    "                val_metrics[\"epoch\"] = epoch\n",
    "                for name, metric in val_metrics.items():\n",
    "                    self.history[name] = self.history.get(name, []) + [metric]\n",
    "            \n",
    "            # 3，early-stopping -------------------------------------------------\n",
    "            arr_scores = self.history[monitor]\n",
    "            best_score_idx = np.argmax(arr_scores) if mode==\"max\" else np.argmin(arr_scores)\n",
    "            if best_score_idx==len(arr_scores)-1:\n",
    "                torch.save(self.net.state_dict(),ckpt_path)\n",
    "                print(\"<<<<<< reach best {0} : {1} >>>>>>\".format(monitor,\n",
    "                     arr_scores[best_score_idx]),file=sys.stderr)\n",
    "            if len(arr_scores)-best_score_idx>patience:\n",
    "                print(\"<<<<<< {} without improvement in {} epoch, early stopping >>>>>>\".format(\n",
    "                    monitor,patience),file=sys.stderr)\n",
    "                break \n",
    "                \n",
    "        self.net.load_state_dict(torch.load(ckpt_path))\n",
    "            \n",
    "        return pd.DataFrame(self.history)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def evaluate(self, val_data):\n",
    "        val_data = self.accelerator.prepare(val_data)\n",
    "        val_step_runner = StepRunner(net = self.net,stage=\"val\",\n",
    "                    loss_fn = self.loss_fn,metrics_dict=deepcopy(self.metrics_dict),\n",
    "                    accelerator = self.accelerator)\n",
    "        val_epoch_runner = EpochRunner(val_step_runner)\n",
    "        val_metrics = val_epoch_runner(val_data)\n",
    "        return val_metrics\n",
    "        \n",
    "       \n",
    "    @torch.no_grad()\n",
    "    def predict(self, dataloader):\n",
    "        dataloader = self.accelerator.prepare(dataloader)\n",
    "        result = torch.cat([self.forward(t[0]) for t in dataloader])\n",
    "        return result.data\n",
    "              "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17c5c1d4",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "outputs": [],
   "source": [
    "from torchkeras.metrics import AUC\n",
    "\n",
    "net = create_net()\n",
    "loss_fn = nn.BCEWithLogitsLoss()\n",
    "\n",
    "metrics_dict = {\"auc\":AUC()}\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=0.005, weight_decay=0.001) \n",
    "\n",
    "model = KerasModel(net,\n",
    "                   loss_fn = loss_fn,\n",
    "                   metrics_dict= metrics_dict,\n",
    "                   optimizer = optimizer\n",
    "                  )         \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c621ef79",
   "metadata": {},
   "outputs": [],
   "source": [
    "dfhistory = model.fit(train_data = dl_train,val_data = dl_val,\n",
    "    epochs=20,\n",
    "    ckpt_path='checkpoint.pt',\n",
    "    patience=3,\n",
    "    monitor='val_auc',\n",
    "    mode='max')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3438d09",
   "metadata": {},
   "source": [
    "```\n",
    "================================================================================2022-08-11 19:39:44\n",
    "Epoch 7 / 20\n",
    "\n",
    "100%|██████████| 313/313 [01:14<00:00,  4.18it/s, train_auc=0.768, train_loss=0.475]\n",
    "100%|██████████| 79/79 [00:03<00:00, 23.94it/s, val_auc=0.767, val_loss=0.477]\n",
    "<<<<<< reach best val_auc : 0.7665905952453613 >>>>>>\n",
    "\n",
    "================================================================================2022-08-11 19:41:02\n",
    "Epoch 8 / 20\n",
    "\n",
    "100%|██████████| 313/313 [01:13<00:00,  4.23it/s, train_auc=0.768, train_loss=0.475]\n",
    "100%|██████████| 79/79 [00:03<00:00, 23.92it/s, val_auc=0.767, val_loss=0.477]\n",
    "<<<<<< reach best val_auc : 0.7671190500259399 >>>>>>\n",
    "\n",
    "================================================================================2022-08-11 19:42:20\n",
    "Epoch 9 / 20\n",
    "\n",
    "100%|██████████| 313/313 [01:13<00:00,  4.25it/s, train_auc=0.769, train_loss=0.475]\n",
    "100%|██████████| 79/79 [00:03<00:00, 23.37it/s, val_auc=0.768, val_loss=0.476]\n",
    "<<<<<< reach best val_auc : 0.768292248249054 >>>>>>\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb0b24c6",
   "metadata": {},
   "source": [
    "### 4, Evaluate the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "540b7e7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_metric(dfhistory, metric):\n",
    "    train_metrics = dfhistory[\"train_\"+metric]\n",
    "    val_metrics = dfhistory['val_'+metric]\n",
    "    epochs = range(1, len(train_metrics) + 1)\n",
    "    plt.plot(epochs, train_metrics, 'bo--')\n",
    "    plt.plot(epochs, val_metrics, 'ro-')\n",
    "    plt.title('Training and validation '+ metric)\n",
    "    plt.xlabel(\"Epochs\")\n",
    "    plt.ylabel(metric)\n",
    "    plt.legend([\"train_\"+metric, 'val_'+metric])\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "036913fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_metric(dfhistory,\"loss\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "703ea78e",
   "metadata": {},
   "source": [
    "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h532tdrdvsj20f40a6dg3.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92792236",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_metric(dfhistory,\"auc\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0c00a66",
   "metadata": {},
   "source": [
    "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h532uca1oij20f40a9aab.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24a5e739",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "outputs": [],
   "source": [
    "model.evaluate(dl_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7fbbd7a3",
   "metadata": {},
   "source": [
    "### 5, Use the model"
   ]
  },
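  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5720a7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import roc_auc_score\n",
    "model.eval()\n",
    "# predict() returns raw logits (the loss is BCEWithLogitsLoss), so apply sigmoid to get probabilities.\n",
    "# torch.sigmoid replaces the deprecated F.sigmoid; .cpu() is needed before .numpy() when running on GPU.\n",
    "preds = torch.sigmoid(model.predict(dl_val)).cpu()\n",
    "labels = torch.cat([x[-1] for x in dl_val])\n",
    "\n",
    "val_auc = roc_auc_score(labels.numpy(),preds.numpy())\n",
    "print(val_auc)\n"
   ]
  },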
  {
   "cell_type": "markdown",
   "id": "8c71b32b",
   "metadata": {},
   "source": [
    "```\n",
    "0.768292257292055\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71ec3b4b",
   "metadata": {},
   "source": [
    "### 6, Save the model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30747dc0",
   "metadata": {},
   "source": [
    "The best model weights have already been saved to the path passed as `ckpt_path` to `model.fit`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55758fb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "net_clone = create_net()\n",
    "net_clone.load_state_dict(torch.load(\"checkpoint.pt\"))\n"
   ]
  },
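  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a99c8a6",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import roc_auc_score\n",
    "net_clone.eval()  # inference mode: disables dropout and uses BatchNorm running statistics, if present\n",
    "with torch.no_grad():  # no gradients are needed for evaluation\n",
    "    # torch.sigmoid replaces the deprecated F.sigmoid\n",
    "    preds = torch.cat([torch.sigmoid(net_clone(x[0])) for x in dl_val])\n",
    "labels = torch.cat([x[-1] for x in dl_val])\n",
    "\n",
    "val_auc = roc_auc_score(labels.numpy(),preds.numpy())\n",
    "print(val_auc)\n"
   ]
  },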
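  {
   "cell_type": "markdown",
   "id": "b887ec5a",
   "metadata": {},
   "source": [
    "**If this book has helped you and you would like to encourage the author, remember to give this project a star ⭐️ and share it with your friends 😊!** \n",
    "\n",
    "If there is anything in the book you would like to discuss further with the author, feel free to leave a message on the WeChat official account \"算法美食屋\". The author's time and energy are limited, so replies will be given as circumstances allow.\n",
    "\n",
    "You can also reply to the official account with the keyword **加群** (\"join group\") to join the reader group and discuss with others.\n",
    "\n",
    "![算法美食屋logo.png](https://tva1.sinaimg.cn/large/e6c9d24egy1h41m2zugguj20k00b9q46.jpg)"
   ]
  }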
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "formats": "ipynb,md",
   "main_language": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
